1 Purpose: redo of the 100k-epochs analysis with 250k training epochs

Trying to figure out which stopping criteria to use. Based on 200 simulations of the basic linear-regression horseshoe architecture (i.e. only 1 neuron).

Data generation:

  • 4 covariates truly (linearly) related to the outcome, plus 100 nuisance variables. All covariate values generated from N(0, 1) via torch_randn(); N(0, 1) noise added to the outcome

    • i.e. \(y = -0.5 x_1 + x_2 - 2 x_3 + 4 x_4 + \epsilon\)
  • 100 observations in training set, 25 in test set

  • true coefficients: -0.5, 1, -2, 4, 0, 0, 0, 0, 0, … (4 non-zero coefficients, 100 zero coefficients)

  • the only stopping criterion employed was a maximum number of epochs (training epochs = 250k)

# combine results files
fname_stem <- here::here("sims", "results", "hshoe_linreg_maxepochs_250k")
sim_seeds <- c()
res_comp <- list()

for (i in 1:8){
  fname <- paste0(fname_stem, i, ".Rdata")
  load(fname)
  sim_seeds <- c(sim_seeds, contents$sim_params$sim_seeds)
  res_comp <- append(res_comp, contents$res)
}

res_comp <- setNames(res_comp, paste0("sim_", 1:length(res_comp)))
final_alphas <- t(sapply(
  res_comp,
  function(X) X$alpha_mat[nrow(X$alpha_mat), ]
))

2 Mainline results: 250k training epochs

2.1 Using naive threshold of 0.05

Criterion for model inclusion: is the “dropout rate parameter” \(\alpha\) lower than our Type I error threshold? This interprets \(\alpha\) as a probability, which is not great (\(\alpha\) is commonly > 1 for the nuisance variables).

2.1.1 Type II errors:

Total count of Type II errors and T2 error rate over 200 simulated datasets:

# count
sum(final_alphas[, 1:4] > 0.05)
## [1] 95
# Type II error rate:
mean(final_alphas[, 1:4] > 0.05)
## [1] 0.11875

All errors come from the first covariate (\(\beta_1 = -0.5\), the smallest-magnitude coefficient). The table below shows \(\alpha\) for the first 4 covariates across all 200 simulated datasets.

2.1.2 T1 error

No spurious variables are selected when setting the \(\alpha\) threshold at 0.05.

# count alphas < 0.05 among nuisance variables
sum(final_alphas[, 5:104] < 0.05)
## [1] 0

2.2 Interpreting \(\alpha\) as posterior-based Wald statistic

A more principled approach might be to compare the dropout parameter \(\alpha\) against the reciprocal of a \(\chi^2\) critical value with 1 degree of freedom, with a Bonferroni correction. This is based on the idea that

  1. (to justify using \(\chi^2 (1)\) distribution) the \(\alpha\) parameter is the inverse of the posterior-based Wald statistic discussed in Liu, Li, Yu 2020 (referred to as LLY 2020);

  2. (to justify Bonferroni) the mean-field assumption used in variational inference assumes independence between the individual \(\alpha\) parameters (\(\alpha_i = \dfrac{Var(\tilde{z_i})}{ \left[ E(\tilde{z_i}) \right] ^2}\)).
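As a sanity check of that definition, \(\alpha\) can be estimated from Monte Carlo draws of \(\tilde z\). The snippet below is illustrative only: it uses a log-normal variable as a hypothetical stand-in for \(\tilde z\), not draws from the fitted network.

```r
# Illustrative check of alpha_i = Var(z_i) / E(z_i)^2, using a
# hypothetical log-normal stand-in for the multiplicative noise z.
set.seed(1)
sigma <- 0.5
z <- rlnorm(1e5, meanlog = 0, sdlog = sigma)
alpha_hat <- var(z) / mean(z)^2
alpha_hat            # for a log-normal, the true value is exp(sigma^2) - 1
exp(sigma^2) - 1
```

For a log-normal the ratio has the closed form \(e^{\sigma^2} - 1\), so the Monte Carlo estimate can be checked directly.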

2.2.1 LLY 2020

LLY 2020 propose the posterior-based Wald statistic

\[W = (\bar\theta - \theta_0)'[V_{\theta \theta} (\bar\nu)]^{-1} (\bar\theta - \theta_0) \overset{d}{\rightarrow} \chi^2(q_\theta)\]

  • \(\theta\) is the parameter(s) of interest (\(\bar\theta\) refers to the posterior mean),

  • \(q_\theta\) the dimension of \(\theta\),

  • \(V_{\theta \theta} (\bar\nu)\) the portion of the posterior covariance matrix relevant to \(\theta\) (\(\nu\) refers to all estimated parameters).
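For a scalar parameter, \(W\) reduces to a squared posterior z-score, \((\bar\theta - \theta_0)^2 / Var(\theta \mid data)\), compared against \(\chi^2(1)\). A sketch on simulated stand-in posterior draws (the mean and sd below are arbitrary illustrative values):

```r
# Illustrative scalar case of the posterior-based Wald statistic:
# W = (posterior mean - theta0)^2 / posterior variance, vs chi^2(1).
set.seed(2)
theta_draws <- rnorm(1e5, mean = 0.8, sd = 0.2)  # stand-in posterior draws
theta0 <- 0
W <- (mean(theta_draws) - theta0)^2 / var(theta_draws)
W > qchisq(0.95, df = 1)  # reject theta = theta0 at the 5% level?
```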

2.2.2 Results for posterior-based Wald interpretation of \(\alpha\)

2.2.2.1 Type II error

Type II error count & rate:

wald_thresh <- 1 / qchisq(1 - (0.05 / 104), df = 1)
t2_sum <- sum(final_alphas[, 1:4] > wald_thresh)
t2_sum # count T2 errors
## [1] 39
mean(final_alphas[, 1:4] > wald_thresh) # rate
## [1] 0.04875

Results: 39 errors out of a possible 800 (4 true covariates × 200 simulations), i.e. a Type II error rate of 0.04875.

2.2.2.2 Type I error

T1 error count & rate (out of 100 nuisance variables * 200 simulations):

sum(final_alphas[, 5:104] < wald_thresh)
## [1] 4
mean(final_alphas[, 5:104] < wald_thresh)
## [1] 0.0002

So the Type I error rate is 0.0002.

2.2.3 Problem with posterior Wald interpretation?

Below is a histogram of \(1/\alpha\) for the 100 nuisance parameters at the last training epoch of the 200 simulations, followed by a histogram of 1000 draws from a \(\chi^2(1)\).

hist(1/final_alphas[, 5:104], xlim = c(0, 5), breaks = 250)

hist(rchisq(1000, df=1), xlim = c(0, 5), breaks = 250)

The two do not match well. However, there are a few possible explanations:

  1. variational inference is known to underestimate variance (which would push the mode to the right);

  2. maybe explaining why there are NO values below 0.7: LLY 2020, Theorem 3.1, clarifies that the Bayesian \(W\) and the frequentist Wald statistic are not quite the same:

\[W = Wald + o_p(1) \overset{d}{\rightarrow} \chi^2(q_\theta)\]

Last, we are only looking at the final training epoch. These simulations simply set a maximum of 250k training epochs, so it’s possible that the network should have trained for a longer period of time.
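A QQ plot of \(1/\alpha\) against \(\chi^2(1)\) quantiles might make this comparison sharper than side-by-side histograms with different bin widths. A sketch, using \(\chi^2(1)\) draws as a stand-in for the actual 1/final_alphas values:

```r
# Sketch: QQ plot of 1/alpha against chi^2(1) quantiles. Here chi^2(1)
# draws stand in for 1 / final_alphas[, 5:104]; with the real values,
# systematic departure from the dashed identity line would be easy to read.
set.seed(3)
inv_alpha <- rchisq(1000, df = 1)  # stand-in for observed 1/alpha values
qqplot(qchisq(ppoints(length(inv_alpha)), df = 1), inv_alpha,
       xlab = "chi^2(1) quantiles", ylab = "1 / alpha")
abline(0, 1, lty = 2)
```

With the real data, an absence of small \(1/\alpha\) values would show up as the lower-left tail lifting off the identity line.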

2.2.4 notes / thoughts:

  • any way to get ELBOs for competing models to approximate Bayes factors? Is this even useful / desirable? Avoiding this kind of computation is part of the reason we’re using NNs in the first place…

3 Examination for reducing FP / FN?

Looking at different stopping criteria for training the network:

  1. “convergence”
  • in NEXT analysis (the results so far save values every 1000 epochs, not every epoch)

  • of test MSE

  • of train MSE

  • of alphas

  2. test, train, or test - train MSE increasing
  • particularly with test MSE > train MSE

  3. rolling mean of the above
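Criterion 3 could be operationalized as: stop at the first checkpoint where the trailing moving average of a monitored series (e.g. test MSE) changes by less than some tolerance. A sketch, with the window `k` and tolerance `tol` as placeholders:

```r
# Sketch of a rolling-mean convergence check. `series` would be, e.g.,
# the per-checkpoint test MSE; k and tol are placeholder settings.
converged_at <- function(series, k = 5, tol = 1e-3) {
  ma <- stats::filter(series, rep(1 / k, k), sides = 1)  # trailing MA
  d <- abs(diff(ma))
  which(d < tol)[1] + 1   # +1: diff() shifts indices by one
}

# toy example: a noisy, exponentially decaying loss curve
set.seed(4)
loss <- 2 * exp(-(1:200) / 30) + rnorm(200, sd = 1e-5)
converged_at(loss, k = 5, tol = 1e-3)
```

`stats::filter` is used here instead of `zoo::rollmean` only to keep the sketch base-R; either works.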

3.1 test, train, test-train MSE increasing

# did simulation have false positive / false negative?
FP_mat <- final_alphas[, 5:104] < wald_thresh
FN_mat <- final_alphas[, 1:4] > wald_thresh
err_df <- data.frame(
  "FP" = apply(FP_mat, 1, any),
  "FN" = apply(FN_mat, 1, any),
  "sim" = paste0("sim_", 1:nrow(FP_mat))
)

err_df$err_type <- ifelse(err_df$FN, "t2" , NA)
err_df$err_type <- ifelse(err_df$FP, "t1" , err_df$err_type)
err_df$err_type <- ifelse(err_df$FP + err_df$FN == 0, "none" , err_df$err_type)

# compile loss_mats from all simulations
res_arr <- sapply(
  res_comp,
  function(X) cbind(X$alpha_mat, X$loss_mat),
  simplify = "array"
)

# loss_arr dims:   1: epoch;   2: kl, mse_train, mse_test;   3: sim_#
loss_arr <- sapply(
  res_comp,
  function(X) X$loss_mat,
  simplify = "array"
)
# alph_arr dims:   1: epoch;   2: coefs   3: sim_#
alph_arr <- sapply(
  res_comp,
  function(X) X$alpha_mat,
  simplify = "array"
)

if (length(unique(sapply(res_arr, nrow))) > 1) {
  res_arr <- res_arr[11:100]
  res_arr <- simplify2array(res_arr)
  loss_arr <- loss_arr[11:100]
  loss_arr <- simplify2array(loss_arr)
  alph_arr <- alph_arr[11:100]
  alph_arr <- simplify2array(alph_arr)
  err_df <- err_df[11:100, ]
}

dimnames(res_arr)[[2]] <- c(
  paste0("alpha0", 1:9),
  paste0("alpha", 10:104),
  "kl", "mse_train", "mse_test" 
)

# colnames(res_arr[, , 1])[105:107]
kl_mat <- res_arr[, 105, ]
train_mat <- res_arr[, 106, ]
test_mat <- res_arr[, 107, ]
tetr_mat <- test_mat - train_mat

3.1.1 train MSE plot

ma_k <- 5

train_MA_df <- data.frame(
  apply(
    loss_arr[, 2, ], 
    2, 
    function(X) zoo::rollmean(X, k = ma_k)
  )
)

train_df <- data.frame(train_mat)
train_df$epoch <- as.numeric(rownames(train_mat))
train_MA_df$epoch <- train_df$epoch[ma_k : length(train_df$epoch)]

train_df_long <- train_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "mse_train")
train_MA_df_long <- train_MA_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "MA_mse_train")

train_df_long <- inner_join(train_df_long, err_df, by = "sim")
train_MA_df_long <- inner_join(train_MA_df_long, err_df, by = "sim")

# sample sims to be able to see
display_sims <- paste0("sim_", sample(1:100, 25))


train_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% display_sims) %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = mse_train, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = "train_MSE by error type (70-250k epochs)"
  )

train_MA_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% display_sims) %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = MA_mse_train, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0(
      "moving average (",
      ma_k,
      ") train_MSE by error type (70-250k epochs)"
    )
  )

3.1.2 test MSE

test_MA_df <- data.frame(
  apply(
    loss_arr[, 3, ], 
    2, 
    function(X) zoo::rollmean(X, k = ma_k)
  )
)

test_df <- data.frame(test_mat)
test_df$epoch <- as.numeric(rownames(test_mat))
test_MA_df$epoch <- test_df$epoch[ma_k : length(test_df$epoch)]

test_df_long <- test_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "mse_test")
test_MA_df_long <- test_MA_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "MA_mse_test")

test_df_long <- inner_join(test_df_long, err_df, by = "sim")
test_MA_df_long <- inner_join(test_MA_df_long, err_df, by = "sim")

test_df_long %>%
  # filter(err_type != "none") %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = mse_test, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = "test_MSE by error type (70-250k epochs)"
  )

test_MA_df_long %>%
  # filter(err_type != "none") %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = MA_mse_test, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0(
      "moving average (",
      ma_k,
      ") test_MSE by error type (70-250k epochs)"
    )
  )

3.1.3 test - train

tetr_MA_df <- data.frame(
  apply(
    loss_arr[, 3, ] - loss_arr[, 2, ], 
    2, 
    function(X) zoo::rollmean(X, k = ma_k)
  )
)

tetr_df <- data.frame(tetr_mat)
tetr_df$epoch <- as.numeric(rownames(tetr_mat))
tetr_MA_df$epoch <- tetr_df$epoch[ma_k : length(tetr_df$epoch)]

tetr_df_long <- tetr_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "mse_tetr")
tetr_MA_df_long <- tetr_MA_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "MA_mse_tetr")

tetr_df_long <- inner_join(tetr_df_long, err_df, by = "sim")
tetr_MA_df_long <- inner_join(tetr_MA_df_long, err_df, by = "sim")

tetr_df_long %>%
  # filter(err_type != "none") %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = mse_tetr, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = "tetr_MSE by error type (70-250k epochs)"
  )

tetr_MA_df_long %>%
  # filter(err_type != "none") %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = MA_mse_tetr, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0(
      "moving average (",
      ma_k,
      ") tetr_MSE by error type (70-250k epochs)"
    )
  )

3.1.4 KL

kl_MA_df <- data.frame(
  apply(
    loss_arr[, 1, ], 
    2, 
    function(X) zoo::rollmean(X, k = ma_k)
  )
)

kl_df <- data.frame(kl_mat)
kl_df$epoch <- as.numeric(rownames(kl_mat))
kl_MA_df$epoch <- kl_df$epoch[ma_k : length(kl_df$epoch)]

kl_df_long <- kl_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "kl")
kl_MA_df_long <- kl_MA_df %>%
  pivot_longer(
    cols = !epoch,
    names_to = "sim",
    values_to = "MA_kl")

kl_df_long <- inner_join(kl_df_long, err_df, by = "sim")
kl_MA_df_long <- inner_join(kl_MA_df_long, err_df, by = "sim")

kl_df_long %>%
  # filter(err_type != "none") %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = kl, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = "kl by error type (70-250k epochs)"
  )

kl_MA_df_long %>%
  # filter(err_type != "none") %>% 
  filter(epoch > 70000) %>% 
  ggplot(
    aes(
      y = MA_kl, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0(
      "moving average (",
      ma_k,
      ") kl by error type (70-250k epochs)"
    )
  )

3.1.5 conclusions from the 250k-training-epoch sims

KL, train MSE, and test MSE do not really offer good clues as to T1 / T2 errors, at least when looking at values recorded every 1k epochs over 250k training epochs.

3.2 Reverse examination

  • look at the point when the alphas cross the threshold and the decisions stop changing
above_w <- apply(alph_arr, c(2, 3), function(X) X>wald_thresh)
dim(above_w)
## [1] 250 104 200
change_w <- apply(above_w, c(2, 3), function(X) diff(X))
change_count_by_epoch <- apply(change_w, c(1, 3), function(X) sum(abs(X)))

rowSums(change_count_by_epoch)
##   2000   3000   4000   5000   6000   7000   8000   9000  10000  11000  12000 
##      0      0      0      0      0      0      0      0      0      0      0 
##  13000  14000  15000  16000  17000  18000  19000  20000  21000  22000  23000 
##      0      0      0      0      0      2     12     26     57     77     90 
##  24000  25000  26000  27000  28000  29000  30000  31000  32000  33000  34000 
##    106     84     69     58     58     36     23     25     17     17     11 
##  35000  36000  37000  38000  39000  40000  41000  42000  43000  44000  45000 
##      5      8      9      7      5      4      3      2      1      1      1 
##  46000  47000  48000  49000  50000  51000  52000  53000  54000  55000  56000 
##      0      2      4      6      3      1      2      4      4      4      2 
##  57000  58000  59000  60000  61000  62000  63000  64000  65000  66000  67000 
##      2      2      5      2      1      1      1      4      3      1      2 
##  68000  69000  70000  71000  72000  73000  74000  75000  76000  77000  78000 
##      1      2      2      1      1      0      0      2      1      1      1 
##  79000  80000  81000  82000  83000  84000  85000  86000  87000  88000  89000 
##      4      2      1      0      0      0      1      0      0      0      0 
##  90000  91000  92000  93000  94000  95000  96000  97000  98000  99000  1e+05 
##      1      2      1      0      1      0      1      3      1      1      2 
## 101000 102000 103000 104000 105000 106000 107000 108000 109000 110000 111000 
##      1      1      1      2      0      2      3      1      1      2      6 
## 112000 113000 114000 115000 116000 117000 118000 119000 120000 121000 122000 
##      2      0      2      2      3      5      4      4      4      3      2 
## 123000 124000 125000 126000 127000 128000 129000 130000 131000 132000 133000 
##      3      4      2      2      0      2      1      1      0      4      5 
## 134000 135000 136000 137000 138000 139000 140000 141000 142000 143000 144000 
##      5      3      3      4      3      1      0      1      1      1      1 
## 145000 146000 147000 148000 149000 150000 151000 152000 153000 154000 155000 
##      2      0      2      0      2      2      1      1      0      1      1 
## 156000 157000 158000 159000 160000 161000 162000 163000 164000 165000 166000 
##      2      2      0      1      3      1      0      1      2      3      4 
## 167000 168000 169000 170000 171000 172000 173000 174000 175000 176000 177000 
##      1      0      1      3      3      3      1      2      2      0      1 
## 178000 179000 180000 181000 182000 183000 184000 185000 186000 187000 188000 
##      0      2      0      3      1      1      0      2      2      1      0 
## 189000 190000 191000 192000 193000 194000 195000 196000 197000 198000 199000 
##      1      2      3      1      3      1      1      1      1      1      2 
##  2e+05 201000 202000 203000 204000 205000 206000 207000 208000 209000 210000 
##      0      2      1      0      0      1      1      1      1      1      0 
## 211000 212000 213000 214000 215000 216000 217000 218000 219000 220000 221000 
##      1      1      2      2      3      2      2      0      1      1      4 
## 222000 223000 224000 225000 226000 227000 228000 229000 230000 231000 232000 
##      0      0      1      2      2      2      2      1      0      1      1 
## 233000 234000 235000 236000 237000 238000 239000 240000 241000 242000 243000 
##      0      1      2      1      0      0      3      1      3      2      0 
## 244000 245000 246000 247000 248000 249000 250000 
##      1      0      2      1      0      1      2
sum(change_count_by_epoch[17:81, ])
## [1] 889
sum(change_count_by_epoch)
## [1] 1137

Most crossings of the Wald-statistic threshold occur between roughly 18k and 82k epochs (889 of 1137 total crossings).

  • What is the behavior between 61k and 100k epochs?
# count changes from False Neg to True Post (good)
sum(change_w[61:99, 1:4,] == -1)
## [1] 21
# count changes from True Pos to False Neg (bad)
sum(change_w[61:99, 1:4,] == 1)
## [1] 14
# how many of these changes are the same variable flipping?
varflips <- apply(
  change_w[61:99, 1:4, ], 
  3, 
  function(X) colSums(abs(X))
)
flips <- apply(varflips, 1, function(X) which(X != 0))
flips
## [[1]]
##  sim_14  sim_24  sim_38  sim_75  sim_84  sim_90  sim_98 sim_114 sim_118 sim_149 
##      14      24      38      75      84      90      98     114     118     149 
## sim_158 sim_166 
##     158     166 
## 
## [[2]]
## named integer(0)
## 
## [[3]]
## named integer(0)
## 
## [[4]]
## named integer(0)

So these flips all occur in variable 1 (the expected behavior, given its small coefficient). Simulations: 14, 24, 38, 75, 84, 90, 98, 114, 118, 149, 158, 166.
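One way to turn this into a stopping rule: stop once the selection decisions have been unchanged for m consecutive checkpoints. A sketch on a toy decision matrix (the data and the window m below are illustrative, not from the simulations):

```r
# Sketch: earliest checkpoint opening a run of at least m change-free
# checkpoints. `decisions` is a checkpoints x variables logical matrix
# (TRUE = selected); the matrix below is a toy stand-in.
stable_from <- function(decisions, m = 10) {
  flips <- rowSums(abs(apply(decisions, 2, diff)))  # decision flips per step
  run <- rle(flips == 0)
  ends <- cumsum(run$lengths)
  i <- which(run$values & run$lengths >= m)[1]
  if (is.na(i)) NA_integer_ else ends[i] - run$lengths[i] + 1
}

decisions <- matrix(FALSE, nrow = 30, ncol = 4)
decisions[10:30, 1] <- TRUE   # variable 1 enters at checkpoint 10 and stays
decisions[12, 2]    <- TRUE   # variable 2 flips in briefly at checkpoint 12
stable_from(decisions, m = 10)  # first checkpoint of the long stable run
```

On the real runs this would be applied to `above_w` per simulation; with checkpoints every 1k epochs, m = 10 corresponds to 10k epochs of unchanged decisions.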

3.2.1 comparison of flips vs noflips

3.2.1.1 train_mse

flip_sims <- paste0("sim_", flips[[1]])
display_sims_noflips <- paste0("sim_", sample(setdiff(1:100, flips[[1]]), 20))

lo_epoch <- 30000
hi_epoch <- 250000
epoch_label <- paste0(lo_epoch/1000, "k-", hi_epoch/1000, "k epochs")

train_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% flip_sims) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = mse_train, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  geom_line(stat = "smooth", alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("flipsims: train_MSE by error type; ", epoch_label)
  ) 

train_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% display_sims_noflips) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = mse_train, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  geom_line(stat = "smooth", alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("noflips: train_MSE by error type; ", epoch_label)
  ) 

  • interestingly, all 4 Type I errors are in simulations which flip. However, not all of the flipping simulations contained T1 errors at 250k epochs.

3.2.1.2 test_mse

test_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% flip_sims) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = mse_test, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  geom_line(stat = "smooth", alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("flipsims: test_MSE by error type; ", epoch_label)
  ) 

test_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% display_sims_noflips) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = mse_test, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  geom_line(stat = "smooth", alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("noflips: test_MSE by error type; ", epoch_label)
  )

Note: test MSE appears to level off around 100k epochs.

3.2.1.3 kl

kl_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% flip_sims) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = kl, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("flipsims: kl by error type; ", epoch_label)
  ) 

kl_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% display_sims_noflips) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = kl, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("noflips: kl by error type; ", epoch_label)
  )

3.2.1.4 test - train

tetr_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% flip_sims) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = mse_tetr, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  geom_line(stat = "smooth", alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("flipsims: test-train MSE by error type; ", epoch_label)
  )

tetr_df_long %>%
  # filter(err_type != "none") %>% 
  filter(sim %in% display_sims_noflips) %>% 
  filter(epoch > lo_epoch & epoch < hi_epoch) %>% 
  ggplot(
    aes(
      y = mse_tetr, 
      x = epoch, 
      group = sim,
      color = sim)
  ) + 
  geom_line(alpha = 0.25) + 
  geom_line(stat = "smooth", alpha = 0.5) + 
  facet_wrap(~err_type, nrow = 3) + 
  theme(legend.position = "none") + 
  labs(
    title = paste0("noflips: test-train MSE by error type; ", epoch_label)
  )

3.3 Conclusions

  • train / test MSE seems a bit too noisy to give us clues as to when the alphas stop flipping across the Wald threshold.

    • Need to look at every iteration instead of every 1000 iterations

    • or look at the first-order differences directly?

  • in any case, KL looks like the thing to monitor: it decreases smoothly.

    • Epsilon used to assess convergence should be scaled by \(n\).
  • overall, training metrics appear fairly stable by around 60k epochs
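A sketch of the KL-based rule suggested above, with the convergence tolerance scaled by the training-set size n (the constant `eps` and the `window` of consecutive small changes are placeholders):

```r
# Sketch: declare convergence when the KL term changes by less than
# eps * n for `window` consecutive checkpoints. eps is a placeholder
# constant; n is the training-set size.
kl_converged_at <- function(kl, n, eps = 1e-4, window = 3) {
  small <- as.numeric(abs(diff(kl)) < eps * n)
  runs <- stats::filter(small, rep(1, window), sides = 1)  # trailing counts
  hit <- which(runs == window)[1]
  if (is.na(hit)) NA_integer_ else hit + 1
}

# toy example: a smoothly decaying KL path, n = 100 training observations
set.seed(5)
kl_path <- 500 * exp(-(1:300) / 30) + rnorm(300, sd = 1e-6)
kl_converged_at(kl_path, n = 100, eps = 1e-4)
```

Scaling by n keeps the rule invariant to whether the KL term is summed or averaged over observations; the right eps would still need tuning against the error-by-epoch tables below.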

3.3.1 at 60k epochs:

get_alphas_by_row <- function(row_num){
  t(sapply(res_comp, function(X) X$alpha_mat[row_num, ]))
}

alphas_60k <- get_alphas_by_row(60)
sum(alphas_60k[, 5:104] < wald_thresh)  # T1
## [1] 4
sum(alphas_60k[, 1:4] > wald_thresh)    # T2
## [1] 53

3.3.2 over epochs

rowdim <- nrow(res_comp[[1]]$loss_mat)
errs_by_epoch <- matrix(NA, nrow = rowdim, ncol = 3)
colnames(errs_by_epoch) <- c("epoch", "T1", "T2")

for (epoch_row in 1:rowdim){
  alphs <- get_alphas_by_row(epoch_row)
  errs_by_epoch[epoch_row, 1] <- epoch_row * 1000
  errs_by_epoch[epoch_row, 2] <- sum(alphs[, 5:104] < wald_thresh)  # FP
  errs_by_epoch[epoch_row, 3] <- sum(alphs[, 1:4] > wald_thresh)    # FN
}

err_rates_by_epoch <- errs_by_epoch
err_rates_by_epoch[, 2] <- errs_by_epoch[, 2] / (100 * 200)
err_rates_by_epoch[, 3] <- errs_by_epoch[, 3] / (4 * 200)

mykable(err_rates_by_epoch[1:(rowdim/10)* 10, ], cap = "err by epoch")
err by epoch
epoch T1 T2
10000 0.00000 1.00000
20000 0.00000 0.95000
30000 0.00115 0.17625
40000 0.00035 0.09375
50000 0.00020 0.07625
60000 0.00020 0.06625
70000 0.00015 0.06000
80000 0.00015 0.05875
90000 0.00020 0.05625
100000 0.00020 0.05625
110000 0.00015 0.05500
120000 0.00015 0.05500
130000 0.00025 0.05000
140000 0.00020 0.05125
150000 0.00020 0.04875
160000 0.00020 0.04875
170000 0.00025 0.05000
180000 0.00025 0.04750
190000 0.00020 0.04750
200000 0.00020 0.05000
210000 0.00020 0.05000
220000 0.00025 0.04750
230000 0.00020 0.04875
240000 0.00020 0.04625
250000 0.00020 0.04875

Only marginal improvements when stopping after 50k epochs, and essentially none after 150k epochs.

mykable(err_rates_by_epoch, cap = "err by epoch")
err by epoch
epoch T1 T2
1000 0.00000 1.00000
2000 0.00000 1.00000
3000 0.00000 1.00000
4000 0.00000 1.00000
5000 0.00000 1.00000
6000 0.00000 1.00000
7000 0.00000 1.00000
8000 0.00000 1.00000
9000 0.00000 1.00000
10000 0.00000 1.00000
11000 0.00000 1.00000
12000 0.00000 1.00000
13000 0.00000 1.00000
14000 0.00000 1.00000
15000 0.00000 1.00000
16000 0.00000 1.00000
17000 0.00000 1.00000
18000 0.00000 0.99750
19000 0.00000 0.98250
20000 0.00000 0.95000
21000 0.00010 0.88125
22000 0.00015 0.78625
23000 0.00030 0.67750
24000 0.00045 0.54875
25000 0.00075 0.45375
26000 0.00110 0.37625
27000 0.00120 0.30875
28000 0.00115 0.24250
29000 0.00105 0.20250
30000 0.00115 0.17625
31000 0.00090 0.15375
32000 0.00075 0.13625
33000 0.00070 0.12625
34000 0.00065 0.11375
35000 0.00065 0.10750
36000 0.00045 0.10250
37000 0.00050 0.09750
38000 0.00040 0.09625
39000 0.00035 0.09375
40000 0.00035 0.09375
41000 0.00030 0.09125
42000 0.00030 0.08875
43000 0.00025 0.08875
44000 0.00025 0.08750
45000 0.00025 0.08625
46000 0.00025 0.08625
47000 0.00020 0.08500
48000 0.00020 0.08000
49000 0.00020 0.07500
50000 0.00020 0.07625
51000 0.00020 0.07500
52000 0.00020 0.07500
53000 0.00020 0.07250
54000 0.00020 0.07250
55000 0.00020 0.07250
56000 0.00020 0.07000
57000 0.00020 0.06750
58000 0.00020 0.07000
59000 0.00020 0.06625
60000 0.00020 0.06625
61000 0.00020 0.06500
62000 0.00020 0.06625
63000 0.00020 0.06750
64000 0.00020 0.06250
65000 0.00020 0.06375
66000 0.00020 0.06500
67000 0.00020 0.06250
68000 0.00020 0.06125
69000 0.00020 0.06125
70000 0.00015 0.06000
71000 0.00015 0.06125
72000 0.00015 0.06000
73000 0.00015 0.06000
74000 0.00015 0.06000
75000 0.00015 0.06000
76000 0.00015 0.05875
77000 0.00015 0.06000
78000 0.00010 0.06000
79000 0.00015 0.05625
80000 0.00015 0.05875
81000 0.00015 0.05750
82000 0.00015 0.05750
83000 0.00015 0.05750
84000 0.00015 0.05750
85000 0.00015 0.05625
86000 0.00015 0.05625
87000 0.00015 0.05625
88000 0.00015 0.05625
89000 0.00015 0.05625
90000 0.00020 0.05625
91000 0.00015 0.05750
92000 0.00015 0.05625
93000 0.00015 0.05625
94000 0.00015 0.05750
95000 0.00015 0.05750
96000 0.00020 0.05750
97000 0.00020 0.05625
98000 0.00020 0.05500
99000 0.00015 0.05500
100000 0.00020 0.05625
101000 0.00015 0.05625
102000 0.00020 0.05625
103000 0.00020 0.05500
104000 0.00025 0.05625
105000 0.00025 0.05625
106000 0.00025 0.05375
107000 0.00025 0.05500
108000 0.00025 0.05375
109000 0.00020 0.05375
110000 0.00015 0.05500
111000 0.00025 0.05250
112000 0.00020 0.05375
113000 0.00020 0.05375
114000 0.00020 0.05125
115000 0.00015 0.05000
116000 0.00020 0.05000
117000 0.00015 0.05250
118000 0.00020 0.05375
119000 0.00025 0.05500
120000 0.00015 0.05500
121000 0.00015 0.05625
122000 0.00020 0.05500
123000 0.00020 0.05375
124000 0.00025 0.05250
125000 0.00025 0.05000
126000 0.00025 0.05250
127000 0.00025 0.05250
128000 0.00025 0.05000
129000 0.00020 0.05000
130000 0.00025 0.05000
131000 0.00025 0.05000
132000 0.00020 0.05375
133000 0.00025 0.04875
134000 0.00020 0.04875
135000 0.00020 0.04500
136000 0.00020 0.04875
137000 0.00020 0.05125
138000 0.00020 0.05250
139000 0.00020 0.05125
140000 0.00020 0.05125
141000 0.00020 0.05000
142000 0.00020 0.05125
143000 0.00020 0.05000
144000 0.00020 0.05125
145000 0.00020 0.04875
146000 0.00020 0.04875
147000 0.00025 0.05000
148000 0.00025 0.05000
149000 0.00020 0.04875
150000 0.00020 0.04875
151000 0.00020 0.05000
152000 0.00025 0.05000
153000 0.00025 0.05000
154000 0.00020 0.05000
155000 0.00020 0.04875
156000 0.00020 0.04875
157000 0.00020 0.04875
158000 0.00020 0.04875
159000 0.00020 0.05000
160000 0.00020 0.04875
161000 0.00020 0.04750
162000 0.00020 0.04750
163000 0.00020 0.04875
164000 0.00020 0.04875
165000 0.00025 0.04875
166000 0.00020 0.05000
167000 0.00025 0.05000
168000 0.00025 0.05000
169000 0.00025 0.04875
170000 0.00025 0.05000
171000 0.00025 0.05125
172000 0.00020 0.04875
173000 0.00025 0.04875
174000 0.00025 0.04875
175000 0.00025 0.04875
176000 0.00025 0.04875
177000 0.00020 0.04875
178000 0.00020 0.04875
179000 0.00025 0.04750
180000 0.00025 0.04750
181000 0.00020 0.05000
182000 0.00020 0.04875
183000 0.00020 0.05000
184000 0.00020 0.05000
185000 0.00020 0.04750
186000 0.00020 0.04750
187000 0.00025 0.04750
188000 0.00025 0.04750
189000 0.00020 0.04750
190000 0.00020 0.04750
191000 0.00020 0.04625
192000 0.00020 0.04750
193000 0.00025 0.04750
194000 0.00025 0.04875
195000 0.00025 0.05000
196000 0.00020 0.05000
197000 0.00020 0.04875
198000 0.00020 0.05000
199000 0.00020 0.05000
200000 0.00020 0.05000
201000 0.00025 0.04875
202000 0.00020 0.04875
203000 0.00020 0.04875
204000 0.00020 0.04875
205000 0.00020 0.05000
206000 0.00020 0.04875
207000 0.00020 0.05000
208000 0.00020 0.05125
209000 0.00020 0.05000
210000 0.00020 0.05000
211000 0.00020 0.05125
212000 0.00020 0.05000
213000 0.00020 0.05000
214000 0.00020 0.05000
215000 0.00025 0.04750
216000 0.00025 0.05000
217000 0.00020 0.04875
218000 0.00020 0.04875
219000 0.00020 0.04750
220000 0.00025 0.04750
221000 0.00020 0.04875
222000 0.00020 0.04875
223000 0.00020 0.04875
224000 0.00025 0.04875
225000 0.00025 0.05125
226000 0.00025 0.04875
227000 0.00020 0.04750
228000 0.00020 0.04750
229000 0.00020 0.04875
230000 0.00020 0.04875
231000 0.00020 0.04750
232000 0.00020 0.04875
233000 0.00020 0.04875
234000 0.00020 0.05000
235000 0.00020 0.04750
236000 0.00020 0.04875
237000 0.00020 0.04875
238000 0.00020 0.04875
239000 0.00020 0.04750
240000 0.00020 0.04625
241000 0.00020 0.04750
242000 0.00020 0.04750
243000 0.00020 0.04750
244000 0.00025 0.04750
245000 0.00025 0.04750
246000 0.00020 0.04875
247000 0.00020 0.04750
248000 0.00020 0.04750
249000 0.00025 0.04750
250000 0.00020 0.04875
# note: this overwrites the earlier per-simulation err_df
err_df <- as.data.frame(err_rates_by_epoch)
err_df %>% 
  filter(epoch > 5E4) %>% 
  pivot_longer(
    cols = -epoch,
    names_to = "err_type",
    values_to = "rate"
  ) %>% 
  ggplot(
    aes(
      y = rate,
      x = epoch,
      color = err_type
    )
  ) + 
  geom_line() + 
  labs(title = "error rate by epoch over 200 simulations")

Note: the error rates appear to level off around 150k epochs.